Add tensor parameter containers #75

Open
fnattino wants to merge 14 commits into main from tensor-param-template

Conversation

@fnattino
Collaborator

@fnattino fnattino commented Jan 15, 2026

relates #25

@fnattino
Collaborator Author

This PR adds:

  • A new parameter template, to be used in the models. This allows all parameters to be automatically broadcast to the same shape, which can be useful in the following scenarios:
    • one parameter is a tensor and the others are scalars: the scalar parameters are automatically broadcast to the tensor shape (see the sketch below)
    • all parameters are scalars, but the driving variables (i.e. the weather data) are tensors: not yet implemented, but I imagine this would happen when instantiating a parameter object, which could take a shape argument.
  • A new Tensor trait, which can be used to define parameters that are expected to be tensors. This simplifies the definition of the parameter/state/rate containers in the models, and it allows the expected dtype of a given parameter/state/rate variable to be defined. In addition, input parameters that are not already tensors are automatically cast to tensors.
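
For illustration, a minimal sketch of the first scenario, using the template and trait introduced in this PR (the parameter names and values here are made up):

import torch
from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.traitlets import Tensor

class Parameters(TensorParamTemplate):
    TSUM1 = Tensor(0.)  # hypothetical parameter, passed as a scalar below
    TSUM2 = Tensor(0.)  # hypothetical parameter, passed as a tensor below

# The scalar TSUM1 is broadcast to the shape of the tensor TSUM2
params = Parameters(dict(TSUM1=800., TSUM2=torch.tensor([900., 950., 1000.])))

params.TSUM1.shape  # torch.Size([3])
params.TSUM2.shape  # torch.Size([3])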

What do you think @SarahAlidoost and @SCiarella? I think this can simplify some expressions in the models - I imagine a lot of the _broadcast_to calls can be removed, since the template already carries out the broadcasting.

@SCiarella
Collaborator

Thanks @fnattino, this looks fantastic 🚀

I really like the template that automatically broadcasts to the correct shape and device at the beginning, because right now we are doing this many times in the integration loops.

Ideally, it would be nice to remove all the calls to _broadcast_to, but right now we have to check in each module if the state variables, the rates, the kiosk, and the parameters all have the same shape/device. Can we later add templates for everything and have the engine do the necessary shape/device check at initialization?

@SarahAlidoost
Collaborator

This PR adds:

  • A new parameter template, to be used in the models. This allows all parameters to be automatically broadcast to the same shape, which can be useful in the following scenarios:

    • one parameter is a tensor and the others are scalars: the scalar parameters are automatically broadcast to the tensor shape
    • all parameters are scalars, but the driving variables (i.e. the weather data) are tensors: not yet implemented, but I imagine this would happen when instantiating a parameter object, which could take a shape argument.
  • A new Tensor trait, which can be used to define parameters that are expected to be tensors. This simplifies the definition of the parameter/state/rate containers in the models, and it allows the expected dtype of a given parameter/state/rate variable to be defined. In addition, input parameters that are not already tensors are automatically cast to tensors.

One thing is the naming "Tensor". If I understood correctly from the class definition, the init function doesn't do anything related to a tensor; it is only a type and a subclass of TraitType, right? I found it a bit confusing when, for example, we do a = Tensor(0.0), because I would assume a is a tensor like torch.Tensor or tf.Tensor. Can we rename it to something else?

What do you think @SarahAlidoost and @SCiarella? I think this can simplify some expressions in the models - I imagine a lot of the _broadcast_to calls can be removed, since the template already carries out the broadcasting.

It is awesome! 🥇 Thanks. I like how things get simpler and cleaner. Just one comment about the naming, see above.

@fnattino
Collaborator Author

fnattino commented Jan 19, 2026

Thank you @SCiarella and @SarahAlidoost for the useful feedback!

@SCiarella :

Ideally, it would be nice to remove all the calls to _broadcast_to, but right now we have to check in each module if the state variables, the rates, the kiosk, and the parameters all have the same shape/device. Can we later add templates for everything and have the engine do the necessary shape/device check at initialization?

Indeed, I think it's a good idea to also add similar containers for states and rates, so all variables are initialized with the correct shape and device!
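
To make that concrete, a rough sketch of what such a state container could look like (the class name and constructor signature here are hypothetical, not part of this PR):

import torch

class TensorStatesTemplate:
    """Hypothetical analogue of TensorParamTemplate for state variables."""

    def __init__(self, shape=(), dtype=torch.float64, device="cpu", **initial_values):
        for name, value in initial_values.items():
            # Cast every initial state value to a tensor and broadcast it to the
            # common shape/device, so the modules can skip per-variable checks.
            state = torch.as_tensor(value, dtype=dtype, device=device)
            setattr(self, name, torch.broadcast_to(state, shape).clone())

# e.g.: states = TensorStatesTemplate(shape=(3,), LAI=0.0, TSUM=0.0)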

@SarahAlidoost:

One thing is the naming "Tensor". If I understood correctly from the class definition, the init function doesn't do anything related to a tensor; it is only a type and a subclass of TraitType, right? I found it a bit confusing when, for example, we do a = Tensor(0.0), because I would assume a is a tensor like torch.Tensor or tf.Tensor. Can we rename it to something else?

My idea was to use Tensor to define variables that are expected to be tensors, similar to how pcse has pcse.traitlets.Float or pcse.traitlets.Bool for floats and booleans, respectively. Right now, all the variables expected to be tensors were marked as generic Any. Variables that are defined as Tensor are automatically checked to be of torch.Tensor type or cast to such a type via the validate method, so for instance:

import torch
from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.traitlets import Tensor

class Parameters(TensorParamTemplate):
    A = Tensor(0.)
    B = Tensor(0, dtype=int)

# Parameters A and B are cast to tensors
params = Parameters(dict(A=0., B=0))

params.A
# tensor(0., dtype=torch.float64)

params.B
# tensor(0)


tmin = _get_drv(drv.TMIN, self.params.shape, dtype=self.dtype, device=self.device)

# Assimilation is zero before crop emergence (DVS < 0)
dvs_mask = (dvs >= 0).to(dtype=self.dtype)
Collaborator Author

Was this needed? Why can't we leave it as a tensor with dtype bool?
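
For reference, a standalone snippet (unrelated to the model code) showing that a bool mask can be used directly in arithmetic with float tensors, since PyTorch promotes the dtype automatically:

import torch

dvs = torch.tensor([-0.1, 0.0, 0.5])
rates = torch.tensor([1.0, 2.0, 3.0])

dvs_mask = dvs >= 0        # dtype stays torch.bool
masked = rates * dvs_mask  # bool is promoted to float automatically

masked
# tensor([0., 2., 3.])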

Comment on lines +209 to +211
SLA = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device)
LVAGE = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device)
LV = torch.zeros((self.MAX_DAYS, *params.shape), dtype=self.dtype, device=self.device)
Collaborator Author

@fnattino fnattino Feb 6, 2026

I realized the choice of having the time dimension as the last axis was probably not the best one. Having the time dimension as the first axis instead allows for automatic broadcasting, i.e.:

import torch

# This works fine
x = torch.zeros((10, 5, 5)) + torch.ones((5, 5))

# This does not
x = torch.zeros((5, 5, 10)) + torch.ones((5, 5))
# RuntimeError: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 2

So I moved the time axis, which only requires minor changes, and only in this module.

(*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device
)
LV = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device)
SLA[..., 0] = params.SLATB(DVS).to(dtype=self.dtype, device=self.device)
Collaborator Author

The output of the afgen should already have the correct dtype and device, right? So there should be no need for the .to(...) here?

# If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask
# A mask (0 if DVS < 0, 1 if DVS >= 0)
DVS = torch.as_tensor(k["DVS"], dtype=self.dtype, device=self.device)
dvs_mask = (DVS >= 0).to(dtype=self.dtype).to(device=self.device)
Collaborator Author

@fnattino fnattino Feb 6, 2026

Same as above, I think we can leave this as a tensor with bool dtype? So we don't have to convert it to bool at lines 434, 441, and 451 below?

# in DALV.
# Note that the actual leaf death is imposed on the array LV during the
# state integration step.
tSPAN = _broadcast_to(
Collaborator Author

This is no longer needed after I moved the time axis to the leading dimension (automatic broadcasting).

# is used.
span_mask = hard_mask.detach() + soft_mask - soft_mask.detach()
else:
span_mask = (s.LVAGE > tSPAN).to(dtype=self.dtype)
Collaborator Author

Conversion to dtype should not be needed here?

self.params_shape = _get_params_shape(self.params)

DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device)
DVS = _broadcast_to(self.kiosk["DVS"], self.params.shape)
Collaborator Author

As discussed, here I make sure the input (i.e. DVS) is broadcast to the correct shape, instead of broadcasting all the outputs. It's only needed in test environments, where DVS is "injected" as an external parameter.

Comment on lines +71 to +87
def _on_CROP_START(
self, day, crop_name=None, variety_name=None, crop_start_type=None, crop_end_type=None
):
"""Starts the crop."""
self.logger.debug(f"Received signal 'CROP_START' on day {day}")

if self.crop is not None:
raise RuntimeError(
"A CROP_START signal was received while self.cropsimulation still holds a valid "
"cropsimulation object. It looks like you forgot to send a CROP_FINISH signal with "
"option crop_delete=True"
)

self.parameterprovider.set_active_crop(
crop_name, variety_name, crop_start_type, crop_end_type
)
self.crop = self.mconf.CROP(day, self.kiosk, self.parameterprovider, shape=self._shape)
Collaborator Author

I need to redefine this function to pass the shape to the crop model.

self.mconf = config

self.parameterprovider = parameterprovider
self._shape = _get_params_shape(self.parameterprovider)
Collaborator Author

Right now the shape is inferred from the parameters only, but we might also infer it from the weather data later on.
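
As an illustration of the idea (not the actual _get_params_shape implementation), inferring a common shape from the parameters alone could look roughly like this:

import torch

def _infer_shape_from_params(params):
    # Hypothetical helper: the common shape is the broadcast shape of all
    # tensor-valued parameters; scalar parameters do not constrain it.
    shapes = [p.shape for p in params.values() if isinstance(p, torch.Tensor)]
    return torch.broadcast_shapes(*shapes) if shapes else ()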

Comment on lines -633 to -639
if x.dim() == 0:
# For 0-d tensors, we simply broadcast to the given shape
return torch.broadcast_to(x, shape)
# The given shape should match x in all but the last axis, which represents
# the dimension along which the time integration is carried out.
# We first append an axis to x, then expand to the given shape
return x.unsqueeze(-1).expand(shape)
Collaborator Author

The last part of this function was only needed when we had to broadcast the tensors in leaf dynamics with the additional time axis. This is not needed anymore, and the function is much "cleaner" now.
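
Presumably the remaining helper is then just a thin wrapper around plain broadcasting, something along these lines (a sketch, not necessarily the exact code in this PR):

import torch

def _broadcast_to(x, shape):
    # With the time axis as the leading dimension, standard broadcasting
    # against the parameter shape is sufficient.
    return torch.broadcast_to(torch.as_tensor(x), shape)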

from diffwofost.physical_models.utils import prepare_engine_input
from . import phy_data_folder

config = Configuration(CROP=DVS_Phenology)
Collaborator Author

Now the Engine implemented in diffwofost passes the shape to the crop model - so the original phenology model from PCSE cannot be used here anymore.

@fnattino fnattino marked this pull request as ready for review February 6, 2026 19:09
@fnattino
Collaborator Author

fnattino commented Feb 6, 2026

Hi @SarahAlidoost @SCiarella, this is now ready to be reviewed. I know it's quite a lot of changes, so I tried to leave as many comments as possible to facilitate the review.
